#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2010 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division

import argparse, sys, os.path

# Command-line interface. Option strings, destinations and defaults are what the
# rest of the script reads from `options`; only the help wording is fixed here
# (typos, an unbalanced parenthesis, and the undocumented ":PATTERN" suffix of -m).
parser = argparse.ArgumentParser(description= """Simulate sequences for tips of a
gene tree using substitution and clock models.""")

parser.add_argument("-n", "--seqlen", metavar="N", type = int, default = 400,
                    help="""Alignment length. (default %(default)d)""")

# Repeatable. No argparse default: the script substitutes JC,1 itself when the
# option is absent.
parser.add_argument("-m", "--model", action="append", help =
"""Substitution model. Comma separated list - model name and model parameters.
(1) JC,mu (2) HKY,mu,kappa,freqs (3) GTR,mu,rates,freqs (default JC,1). May be
given more than once: the first model is the default, and each additional one
must end with :PATTERN, a regular expression selecting trees by name.""")

# Repeatable, same append semantics as --model.
parser.add_argument("-c", "--clock", action="append", metavar="DIST", default = None,
                    help =
"""Molecular clock. The mutation rate of each branch is multiplied by a random value
drawn from DIST. A number by itself specifies a constant. Otherwise, a comma
separated list where the first character specifies the distribution (see the
documentation for supported distributions). The base branch mutation rate is the one
specified by the model, but may be overridden by an attribute specified in the tree.""")

parser.add_argument("-a", "--annotate", metavar="FILE", default = None,
                    help="Output trees with sequences embedded in tips annotation. By default,"
                    " generate one NEXUS alignment file per tree.")

parser.add_argument("-o", "--nexus", dest="nexfile", metavar="FILE", default = None,
                    help="Name of annotated trees file or base of alignment file.")

parser.add_argument('trees', metavar='FILE-OR-TREE',
                    help="Trees file (or a NEWICK tree string)")

import re

from biopy import INexus, submodels, __version__
from biopy.treeutils import toNewick, TreeLogger
from biopy.randomDistributions import parseDistribution

options = parser.parse_args()

# The positional argument is either the path of a trees file or a literal
# NEWICK string; an existing file takes precedence.
targ = options.trees
if os.path.isfile(targ) :
  # Context manager so the handle is released even when reading fails. The
  # trees are materialised with list() before the file is closed.
  with open(targ) as treesFile :
    trees = list(INexus.INexus().read(treesFile))
else :
  # startswith() copes with an empty argument, where targ[0] would raise.
  if not targ.startswith('(') :
    sys.stderr.write("Command line argument is not a file or a newick tree\n")
    sys.exit(1)

  try :
    trees = [INexus.Tree(targ)]
  except Exception as e:
    # Top-level boundary: report any parser failure and exit instead of
    # dumping a traceback.
    sys.stderr.write("*** Error: failed to parse tree: %s [ %s ]\n" % (e, targ))
    sys.exit(1)

# Number of sites to simulate; must be strictly positive.
seqLen = options.seqlen
if seqLen <= 0 :
  sys.stderr.write("*** Error: alignment length must be positive\n")
  parser.print_help(sys.stderr)
  sys.exit(1)

# Clock distributions from -c. Each value is "DIST" or "DIST:PATTERN"; the
# result is a list of [distribution, pattern-or-None] pairs.
clocksDist = []
for clock in (options.clock or []) :
  pat = None
  if clock.find(':') > 0 :
    # Split the clock specification itself (the original split an unrelated,
    # not yet defined, name). Only the first colon separates spec from pattern.
    clock,pat = clock.split(':', 1)
  try :
    clockDist = parseDistribution(clock)
  except RuntimeError as e:
    sys.stderr.write("*** Error: parsing clock distribution %s\n" % e)
    sys.exit(1)
  clocksDist.append([clockDist,pat])

if clocksDist :
  # Stable sort placing the pattern-less (catch-all) clock, if any, first.
  clocksDist = sorted(clocksDist, key = lambda x: 0 if x[1] is None else 1)

  # Count only the pattern-less entries, not every entry.
  if sum(1 for c,p in clocksDist if p is None) > 1 :
    sys.stderr.write("""Only one clock distribution can be the default (catch-all), the rest
 require a pattern.\n""")

    sys.exit(1)

def parseFreqs(freqs) :
  """ Parse the code,value pairs that trail a model specification.

  freqs is a flat sequence code1,value1,code2,value2,... A one-letter code
  sets the stationary frequency of that nucleotide; a two-letter code sets the
  substitution rate between the two nucleotides. Codes of any other length are
  skipped, and a dangling last element without a value is ignored.

  Returns (pi, rates): pi holds 4 frequencies, with unspecified ones sharing
  the remaining probability mass equally; rates holds 6 values, each divided
  by the content of the last internal rate slot.

  Raises ValueError when a value is not a number (or a nucleotide letter is
  not found by NUCS.index), and RuntimeError for values outside their legal
  range.
  """
  pi = [0]*4
  qMat = [1]*7

  while len(freqs) >= 2 :
    code,f = freqs[:2]

    try:
      f = float(f)
    except (TypeError, ValueError) :
      raise ValueError("Can't parse number: %s" % (f,))

    freqs = freqs[2:]

    if len(code) == 1 :
      c = submodels.SubstitutionModel.NUCS.index(code)
      if not 0 <= c < 4 :
        raise RuntimeError("Can't parse nucleotide frequency: " + code)
      if not 0 < f < 1 :
        # %-formatting: concatenating a str and a float raises TypeError.
        raise RuntimeError("Illegal nucleotide frequency: %s" % f)
      pi[c] = f
    elif len(code) == 2 :
      c1,c2 = sorted([submodels.SubstitutionModel.NUCS.index(c) for c in code])
      if c1 == c2 :
        raise RuntimeError("Illegal nucleotide specification: " + code)
      if f <= 0 :
        raise RuntimeError("Illegal rate: %s" % f)
      # Index pairs (0,1),(0,2),(0,3),(1,2),(1,3) land in slots 0..4 and (2,3)
      # in slot 6, the slot used for normalisation below.
      # NOTE(review): slot 5 is never assigned, so the sixth returned rate is
      # always 1/qMat[6] - confirm this is the layout StationaryGTR expects.
      qMat[2*c1 + c2-1] = f

  total = sum(pi)
  notAssigned = pi.count(0)
  if notAssigned > 0 :
    if total >= 1 :
      # Raise a real exception; the original raised a plain string.
      raise RuntimeError("Illegal nucleotide stationary frequencies: sum to more than 1")
    # 1.0 keeps this a true division regardless of the division mode in effect.
    p = (1.0 - total)/notAssigned
    pi = [p if x == 0 else x for x in pi]
  elif total != 1 :
    # Warn only about a real discrepancy, not floating point rounding noise.
    if abs(total - 1) > 1e-9 :
      sys.stderr.write("Nucleotide frequencies do not sum to 1 - normalizing...\n")
    pi = [x/total for x in pi]

  scale = float(qMat[-1])
  return pi,[x/scale for x in qMat[:-1]]

#print options.model
def _modelParam(fields, k, default = 1) :
  """ Numeric value of fields[k], or default when that field is missing or empty. """
  return float(fields[k]) if len(fields) > k and fields[k] else default

if not options.model :
  # Same shape argparse produces: a list of plain specification strings.
  options.model = ["JC,1"]

# Substitution models from -m, as [model, pattern-or-None] pairs in command
# line order.
models = []
for model in options.model:
  pat = None
  if model.find(':') > 0 :
    # Only the first colon separates the specification from the pattern.
    model,pat = model.split(':', 1)

  try :
    model = model.split(',')
    modelName = model[0].upper()
    if modelName == 'JC' :
      mu = _modelParam(model, 1)
      smodel = submodels.JCSubstitutionModel(mu)
    elif modelName == 'HKY':
      mu = _modelParam(model, 1)
      kappa = _modelParam(model, 2)
      freqs = model[3:]
      pi = parseFreqs(freqs)[0]
      smodel = submodels.HKYSubstitutionModel(mu = mu, kappa = kappa, pi = pi)
    elif modelName == 'GTR':
      mu = _modelParam(model, 1)
      freqs = model[2:]
      pi,qMat = parseFreqs(freqs)
      smodel = submodels.StationaryGTR(mu = mu, m = qMat, pi = pi)
    else :
      sys.stderr.write("model %s ??\n" % modelName)
      sys.exit(1)
  except (RuntimeError, ValueError) as r:
    # ValueError covers unparsable numbers, which previously escaped as a
    # traceback.
    sys.stderr.write("error in parsing model. %s\n" % r)
    sys.exit(1)

  models.append([smodel,pat])
  del smodel

if models[0][1] is not None or any(p is None for m,p in models[1:]) :
  sys.stderr.write("""First model must be the default (catch-all), the rest
 require a pattern.\n""")

  sys.exit(1)
  
from biopy.INexus import exportMatrix

# Optional log of annotated trees (-a FILE).
tlog = None
if options.annotate :
  tlog = TreeLogger(options.annotate, argv = sys.argv, version = __version__)

try :
  for count,tree in enumerate(trees):
    # Substitution model: the first additional model whose pattern matches the
    # tree name, otherwise the default (first) model.
    smodel = models[0][0]
    if tree.name :
      for m,p in models[1:]:
        if re.search(p, tree.name) :
          smodel = m
          break

    # Clock: the first patterned clock matching the tree name, otherwise the
    # catch-all clock when there is one. Every patterned entry is tested (the
    # first list entry is the catch-all only if its pattern is None), and
    # unnamed trees skip matching since re.search cannot take None.
    clockDist = None
    if clocksDist :
      if tree.name :
        for c,p in clocksDist :
          if p is not None and re.search(p, tree.name) :
            clockDist = c
            break
      if clockDist is None and clocksDist[0][1] is None:
        clockDist = clocksDist[0][0]

    smodel.populateTreeSeqBySimulation(tree, seqLen, clockDist)

    # Store each tip's sequence, converted by toNucCode, in its attributes.
    for n in tree.get_terminals() :
      data = tree.node(n).data
      if not hasattr(data, "attributes") :
        data.attributes = dict()
      data.attributes['seq'] = smodel.toNucCode(data.seq)

    if tlog :
      if len(models) > 1 :
        # With several models, record the one used as a root attribute.
        rootData = tree.node(tree.root).data
        if not hasattr(rootData,"attributes") :
          rootData.attributes = dict()
        rootData.attributes['mu'] = str(smodel.mu)

      treeTxt = toNewick(tree, attributes = 'attributes')
      tlog.outTree(treeTxt, name = tree.name)

    # One alignment file per tree, unless only the annotated trees log was
    # requested.
    if options.nexfile or not tlog:
      base = tree.name or options.nexfile or "alignment"
      if len(trees) > 1 and not tree.name :
        # Unnamed trees share a base name; number them to keep files apart.
        fname = "%s%d" % (base, count)
      else :
        fname = base

      if not fname.endswith(".nex") :
        fname = fname + ".nex"

      seqs = dict()
      for n in tree.get_terminals() :
        data = tree.node(n).data
        seqs[data.taxon] = data.attributes['seq']

      versionTxt = (", version " + __version__) if __version__ is not None else ""
      # Context manager: the file is closed even if the export fails.
      with open(fname, "w") as f :
        f.write("[Generated by %s%s]\n" % (os.path.basename(sys.argv[0]), versionTxt))
        f.write("[%s]\n" % " ".join(sys.argv))

        exportMatrix(f, seqs)
finally :
  # Close the trees log even when a simulation or export step fails.
  if tlog:
    tlog.close()

